import numpy as np
import matplotlib.pyplot as plt


def py_cpu_nms(nms_show, dets, thresh, *, pause=5):
    """Pure-Python greedy non-maximum suppression (NMS) baseline, animated.

    Boxes are visited in descending score order.  Each visited box is kept,
    and every remaining box whose IoU with it is strictly greater than
    ``thresh`` is discarded.  After every step the intermediate arrays are
    printed and two panels are drawn onto ``nms_show``: all boxes with the
    kept one highlighted, and the boxes still alive after the step.

    Args:
        nms_show: matplotlib Figure that the two per-step subplots are added
            to (``add_subplot(1, 2, 1)`` / ``add_subplot(1, 2, 2)``).
        dets: 2-D array whose rows are ``[x1, y1, x2, y2, score]``.
            Coordinates are treated as inclusive pixel indices, hence the
            ``+ 1`` in every width/height below.
        thresh: IoU threshold; boxes with IoU > thresh are suppressed.
        pause: seconds to wait after drawing each step (default 5, the
            previously hard-coded value).

    Returns:
        list of row indices into ``dets`` that survive suppression, in the
        order they were kept (descending score).
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]
    # Area of every box (inclusive pixel convention, hence +1).
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
    print("all_areas:", areas)
    # Candidate indices sorted by descending confidence.
    order = scores.argsort()[::-1]
    print("order:", order)
    keep = []  # indices of the boxes that survive suppression
    print("show_boxid", order)
    while order.size > 0:
        # The first remaining index always has the highest score: keep it.
        i = order[0]
        keep.append(i)
        print("\nkeep(被留下来的框id):", keep)
        rest = order[1:]  # candidates still to be compared against box i
        # Intersection rectangle: larger top-left corner, smaller bottom-right.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        print("xx1,yy1:", xx1, yy1)
        print("xx2,yy2:", xx2, yy2)
        print("w,h:", xx2 - xx1 + 1, yy2 - yy1 + 1)
        # Clamp at 0 so non-overlapping pairs get a zero intersection area.
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        inter = w * h
        print("w*h:", inter)
        # IoU = intersection / (area_i + area_j - intersection)
        iou = inter / (areas[i] + areas[rest] - inter)
        print("---iou:", iou)
        # Positions within `rest` of the boxes that survive / are suppressed.
        indx = np.where(iou <= thresh)[0]
        dedx = np.where(iou > thresh)[0]
        print("order[0]:", i)
        print("show_boxid:", np.append(keep, rest))
        delete_boxid = rest[dedx]
        show_boxid = rest[indx]
        print("after_iou_show_boxid:", np.append(keep, show_boxid))
        print("after_iou_delete_boxid:", delete_boxid, '\n')
        step = len(keep)  # 1-based step number
        # Animate: left panel shows all boxes with the kept one highlighted,
        # right panel shows the boxes still alive after this step.
        ax1 = nms_show.add_subplot(1, 2, 1)
        ax1.set_title('begin_nms {} \nkeep:{} score:{}({})'.format(
            step, i, scores[i], step))
        plot_bbox(dets, 'k', show_ids=np.arange(len(dets)), keep_id=i)
        ax2 = nms_show.add_subplot(1, 2, 2)
        ax2.set_title('after_nms {} \nkeep:{} delete:{}(iou={})'.format(
            step, i, delete_boxid, iou[dedx]))
        alive = np.append(show_boxid, keep)
        plot_bbox(dets[alive], 'b', alive)
        plt.pause(pause)
        ax1.remove()
        ax2.remove()
        # Survivors keep their descending-score order and form the next queue.
        order = show_boxid
        print("-----------afer_order:", order, '-----------')
    return keep


def plot_bbox(dets, c='k', show_ids=None, keep_id=None):
    """Draw boxes, corner markers and labels onto the current matplotlib axes.

    Args:
        dets: 2-D array whose rows are ``[x1, y1, x2, y2, score]``.
        c: matplotlib format string for the box edges (e.g. ``'k'``, ``'b'``).
        show_ids: ids printed, together with the score, next to the first
            ``len(show_ids)`` boxes.  ``None`` (default) draws no labels.
        keep_id: row of ``dets`` to highlight with a filled yellow rectangle,
            or ``None`` (default) for no highlight.  ``0`` is a valid row and
            is highlighted.  Previously ``0`` doubled as the "no highlight"
            sentinel, so box 0 was never marked when NMS kept it.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    score = dets[:, 4]
    print(dets.shape)
    plt.scatter(x1, y1, s=25, c='b', alpha=0.6)  # top-left corners in blue
    plt.scatter(x2, y2, s=25, c='r', alpha=0.6)  # bottom-right corners in red
    # One call per edge; each call draws that edge for every box at once.
    plt.plot([x1, x2], [y1, y1], c)
    plt.plot([x1, x1], [y1, y2], c)
    plt.plot([x1, x2], [y2, y2], c)
    plt.plot([x2, x2], [y1, y2], c)
    plt.xlim((60, 450))
    plt.ylim((450, 60))  # inverted y-axis: y grows downward
    # Move the x-axis ticks to the top of the plot.
    ax = plt.gca()
    ax.spines["top"].set_color("k")
    ax.xaxis.set_ticks_position("top")
    if show_ids is None:
        show_ids = []
    for i, box_id in enumerate(show_ids):
        plt.text(x1[i], y1[i] + 7, "(%d)%.2f" % (box_id, score[i]),
                 fontdict={'size': 10, 'color': 'r'},
                 bbox={'facecolor': 'blue', 'alpha': 0.1})
    if keep_id is not None:
        ax.add_patch(plt.Rectangle(
            (x1[keep_id], y1[keep_id]),
            x2[keep_id] - x1[keep_id] + 1,
            y2[keep_id] - y1[keep_id] + 1,
            color="y", fill=True, linewidth=2))


def main():
    """Run the NMS demo on nine hand-made boxes and animate every step."""
    # Rows are [x1, y1, x2, y2, score]; the trailing comment is the box id.
    boxes = np.array([
        [100, 100, 210, 210, 0.72],  # 0
        [280, 290, 420, 420, 0.8],   # 1
        [220, 220, 320, 330, 0.92],  # 2
        [105, 90, 220, 210, 0.71],   # 3
        [230, 240, 325, 330, 0.81],  # 4
        [305, 300, 420, 420, 0.9],   # 5
        [215, 225, 305, 328, 0.6],   # 6
        [150, 260, 290, 400, 0.99],  # 7
        [102, 108, 208, 208, 0.72]])  # 8
    plt.ion()
    fig = plt.figure(figsize=[14, 9])
    ax1 = plt.subplot(1, 2, 1)
    ax1.set_title('before_nms')
    ax2 = plt.subplot(1, 2, 2)
    ax2.set_title('after_nms')
    plt.sca(ax1)  # draw the initial boxes into the left subplot
    # No keep_id: nothing has been kept yet, so no box is highlighted.
    plot_bbox(boxes, 'k', show_ids=np.arange(len(boxes)))
    keep = py_cpu_nms(fig, boxes, thresh=0.7)
    print("last_keep:", keep)
    plt.ioff()
    plt.pause(2)
    plt.close('all')


if __name__ == "__main__":
    main()